package main

import (
	"bytes"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"os"
	"sort"
	"strings"
	"text/template"

	"golang.org/x/tools/go/packages"
)

// check aborts the generator by panicking with err when err is non-nil.
// It does nothing for a nil error.
func check(err error) {
	if err == nil {
		return
	}
	panic(err)
}

// platformtestFuncDeclFinder is the ast.Visitor state used to collect the
// platformtest case functions of one parsed file. Its zero value is ready
// to use.
type platformtestFuncDeclFinder struct {
	// testFuncs accumulates matching declarations in walk order.
	testFuncs []*ast.FuncDecl
}

func isPlatformtestFunc(n *ast.FuncDecl) bool {
	if !n.Name.IsExported() {
		return false
	}
	if n.Recv != nil {
		return false
	}
	if n.Type.Results.NumFields() != 0 {
		return false
	}
	if n.Type.Params.NumFields() != 1 {
		return false
	}
	se, ok := n.Type.Params.List[0].Type.(*ast.StarExpr)
	if !ok {
		return false
	}
	sel, ok := se.X.(*ast.SelectorExpr)
	if !ok {
		return false
	}
	x, ok := sel.X.(*ast.Ident)
	if !ok {
		return false
	}
	if x.Name != "platformtest" || sel.Sel.Name != "Context" {
		return false
	}
	return true
}

// Visit implements ast.Visitor. Only the *ast.File root is descended into,
// so just top-level declarations are inspected. Each *ast.FuncDecl accepted
// by isPlatformtestFunc is appended to e.testFuncs. Function bodies and all
// other nodes (including the nil sentinel ast.Walk passes at the end) are
// not walked further.
func (e *platformtestFuncDeclFinder) Visit(n2 ast.Node) ast.Visitor {
	if _, isFile := n2.(*ast.File); isFile {
		return e
	}
	if fd, isFunc := n2.(*ast.FuncDecl); isFunc && isPlatformtestFunc(fd) {
		e.testFuncs = append(e.testFuncs, fd)
	}
	return nil
}

// generatedCasesFile is the file this tool (re)writes, relative to the
// current working directory.
const generatedCasesFile = "generated_cases.go"

// main regenerates generatedCasesFile from the package matched by os.Args[1].
//
// It collects every top-level function in that package's non-test Go files
// that isPlatformtestFunc accepts. It then writes a gofmt'ed
// `var Cases = []Case{...}` listing them in lexical order.
// The output hard-codes `package tests` and refers to a `Case` type that is
// presumably declared elsewhere in that package. Any failure panics, which
// aborts the generator.
func main() {
	if len(os.Args) < 2 {
		panic("usage: pass the package pattern to scan as the first argument")
	}
	pattern := os.Args[1]

	removeStaleGeneratedCases()

	p := loadSinglePackage(pattern)
	tests := collectTestFuncs(p.GoFiles)

	src, err := renderCases(tests)
	check(err)
	check(os.WriteFile(generatedCasesFile, src, 0664))
}

// removeStaleGeneratedCases deletes a previously generated
// generatedCasesFile before the package is loaded.
//
// A missing file is not an error. A file without a Go "Code generated ...
// DO NOT EDIT." marker is never deleted; the generator panics instead, so a
// hand-written file of the same name is not lost. Other I/O errors panic
// rather than being silently ignored.
func removeStaleGeneratedCases() {
	content, err := os.ReadFile(generatedCasesFile)
	if os.IsNotExist(err) {
		return
	}
	check(err)
	if !hasGeneratedHeader(content) {
		panic(generatedCasesFile + " exists but is not marked as generated code; refusing to delete it")
	}
	// Tolerate the file vanishing between the read and the remove.
	if err := os.Remove(generatedCasesFile); err != nil && !os.IsNotExist(err) {
		panic(err)
	}
}

// hasGeneratedHeader reports whether src contains a line following Go's
// generated-code convention: "// Code generated <anything> DO NOT EDIT."
func hasGeneratedHeader(src []byte) bool {
	for _, line := range strings.Split(string(src), "\n") {
		line = strings.TrimSuffix(line, "\r")
		if strings.HasPrefix(line, "// Code generated ") && strings.HasSuffix(line, " DO NOT EDIT.") {
			return true
		}
	}
	return false
}

// loadSinglePackage resolves pattern to exactly one package (test files
// excluded) with its file lists populated.
//
// packages.Load reports per-package problems via Package.Errors rather than
// its error return. Those are printed and treated as fatal so that a bad
// pattern cannot silently yield an empty case list.
func loadSinglePackage(pattern string) *packages.Package {
	pkgs, err := packages.Load(
		&packages.Config{
			Mode:  packages.NeedFiles,
			Tests: false,
		},
		pattern,
	)
	check(err)
	if packages.PrintErrors(pkgs) > 0 {
		panic("errors while loading package " + pattern)
	}
	if len(pkgs) != 1 {
		panic(pkgs)
	}
	return pkgs[0]
}

// collectTestFuncs parses each file and returns all declarations accepted
// by isPlatformtestFunc, sorted by function name.
func collectTestFuncs(files []string) []*ast.FuncDecl {
	fset := token.NewFileSet()
	var tests []*ast.FuncDecl
	for _, f := range files {
		a, err := parser.ParseFile(fset, f, nil, parser.AllErrors)
		check(err)
		finder := &platformtestFuncDeclFinder{}
		ast.Walk(finder, a)
		tests = append(tests, finder.testFuncs...)
	}
	sort.Slice(tests, func(i, j int) bool {
		return tests[i].Name.Name < tests[j].Name.Name
	})
	return tests
}

// renderCases produces the gofmt'ed source of generatedCasesFile, listing
// the given functions (in the given order) as elements of Cases.
func renderCases(tests []*ast.FuncDecl) ([]byte, error) {
	const casesTemplate = `
// Code generated by zrepl tooling; DO NOT EDIT.

package tests

var Cases = []Case {
{{- range . -}}
	{{ .Name }},
{{ end -}}
}

	`
	t, err := template.New("CaseFunc").Parse(casesTemplate)
	if err != nil {
		return nil, err
	}

	var buf bytes.Buffer
	if err := t.Execute(&buf, tests); err != nil {
		return nil, err
	}
	// The template's whitespace is sloppy; format.Source normalizes it.
	return format.Source(buf.Bytes())
}
